--[[

Example of "coupled" separate encoder and decoder networks, e.g. for sequence-to-sequence networks.

]]--

require 'rnn'

-- Hyper-parameters for this toy encoder/decoder experiment.
local opt = {
   version      = 1.5,  -- fixed setHiddenState(0) bug
   learningRate = 0.1,
   hiddenSize   = 6,
   numLayers    = 1,
   vocabSize    = 7,
   seqLen       = 7,    -- length of the encoded sequence (with padding)
   niter        = 1000,
}

--[[ Forward coupling: copy the encoder's final hidden state (output and, for
LSTMs, cell state) into the decoder's initial (step 0) hidden state.
@param enc encoder with an `lstmLayers` array of nn.Sequencer-wrapped RecLSTMs
@param dec decoder with a parallel `lstmLayers` array (same layer count)
@param seqLen (optional) encoder time-step to copy the state from;
       defaults to opt.seqLen for backward compatibility
]]--
function forwardConnect(enc, dec, seqLen)
   seqLen = seqLen or opt.seqLen
   for i = 1, #enc.lstmLayers do
      dec.lstmLayers[i]:setHiddenState(0, enc.lstmLayers[i]:getHiddenState(seqLen))
   end
end

--[[ Backward coupling: copy the decoder's gradient w.r.t. its initial (step 0)
hidden state back into the encoder's last time-step, so the encoder receives
gradients through the state coupling rather than through its output tensor.
@param enc encoder with an `lstmLayers` array of nn.Sequencer-wrapped RecLSTMs
@param dec decoder with a parallel `lstmLayers` array (same layer count)
@param seqLen (optional) encoder time-step to receive the gradient;
       defaults to opt.seqLen for backward compatibility
]]--
function backwardConnect(enc, dec, seqLen)
   seqLen = seqLen or opt.seqLen
   for i = 1, #enc.lstmLayers do
      enc.lstmLayers[i]:setGradHiddenState(seqLen, dec.lstmLayers[i]:getGradHiddenState(0))
   end
end

-- Encoder: embed the (zero-padded) input tokens, run them through a stack of
-- zero-masked RecLSTM layers, and keep only the last time-step's output.
local enc = nn.Sequential()
   :add(nn.LookupTableMaskZero(opt.vocabSize, opt.hiddenSize))
enc.lstmLayers = {}
for depth = 1, opt.numLayers do
   local lstm = nn.Sequencer(nn.RecLSTM(opt.hiddenSize, opt.hiddenSize):maskZero())
   enc.lstmLayers[depth] = lstm
   enc:add(lstm)
end
enc:add(nn.Select(1, -1)) -- last element along the time dimension

-- Decoder: embed, run a zero-masked RecLSTM stack (its initial state is
-- injected by forwardConnect), then apply a per-time-step linear projection
-- and log-softmax over the vocabulary, all masked at padded steps.
local dec = nn.Sequential()
   :add(nn.LookupTableMaskZero(opt.vocabSize, opt.hiddenSize))
dec.lstmLayers = {}
for depth = 1, opt.numLayers do
   local lstm = nn.Sequencer(nn.RecLSTM(opt.hiddenSize, opt.hiddenSize):maskZero())
   dec.lstmLayers[depth] = lstm
   dec:add(lstm)
end
dec:add(nn.Sequencer(nn.MaskZero(nn.Linear(opt.hiddenSize, opt.vocabSize), 1)))
dec:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1)))

-- Per-time-step class NLL, ignoring padded (zero-masked) steps.
local criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(), 1))

-- Toy data (batchsize = 2) with variable-length input and output sequences.
-- All tensors are transposed to time-major layout (seqLen x batchsize).

-- Encoder inputs, left-padded with zeros.
local encInSeq = torch.Tensor({{0,0,0,0,1,2,3},{0,0,0,4,3,2,1}}):t()
-- Decoder inputs, right-padded with zeros; label '6' marks start of sentence (GO).
local decInSeq = torch.Tensor({{6,1,2,3,4,0,0,0},{6,5,4,3,2,1,0,0}}):t()
-- Expected decoder outputs (one label per time-step), right-padded with
-- zeros; label '7' marks end of sentence (EOS).
local decOutSeq = torch.Tensor({{1,2,3,4,7,0,0,0},{5,4,3,2,1,7,0,0}}):t()

-- The zeroMasks zero intermediate RNN states wherever zeroMask == 1.
-- To demonstrate both APIs, each mask is randomly either derived from the
-- input sequence or written out explicitly (the two forms are equivalent).
local encZeroMask, decZeroMask
if math.random() < 0.5 then
   encZeroMask = nn.utils.getZeroMaskSequence(encInSeq) -- auto zeroMask from input sequence
else
   encZeroMask = torch.ByteTensor({{1,1,1,1,0,0,0},{1,1,1,0,0,0,0}}):t():contiguous() -- explicit zeroMask
end
if math.random() < 0.5 then
   decZeroMask = nn.utils.getZeroMaskSequence(decInSeq)
else
   decZeroMask = torch.ByteTensor({{0,0,0,0,0,1,1,1},{0,0,0,0,0,0,1,1}}):t():contiguous()
end

-- Zero gradient buffer for the encoder output: the decoder receives the
-- encoder's final state through forwardConnect (not through encOut), so the
-- gradient flowing back into the encoder's output is identically zero.
-- Preallocated once and reused (resized) every iteration.
-- NOTE: the previous code used torch.Tensor(encOut):zero(), but that
-- constructor *shares* encOut's storage, silently zeroing encOut in place;
-- a separate buffer avoids the aliasing.
local zeroGrad = torch.Tensor()

for i = 1, opt.niter do
   enc:zeroGradParameters()
   dec:zeroGradParameters()

   -- Zero-masking: zero RNN states and gradients at padded time-steps.
   enc:setZeroMask(encZeroMask)
   dec:setZeroMask(decZeroMask)
   criterion:setZeroMask(decZeroMask)

   -- Forward pass: encode, couple the final state into the decoder, decode.
   local encOut = enc:forward(encInSeq)
   forwardConnect(enc, dec)
   local decOut = dec:forward(decInSeq)
   --print(decOut)
   local err = criterion:forward(decOut, decOutSeq)

   print(string.format("Iteration %d ; NLL err = %f ", i, err))

   -- Backward pass: backprop the decoder, couple its state gradient back
   -- into the encoder, then backprop the encoder with a zero output gradient.
   local gradOutput = criterion:backward(decOut, decOutSeq)
   dec:backward(decInSeq, gradOutput)
   backwardConnect(enc, dec)
   zeroGrad:resizeAs(encOut):zero()
   enc:backward(encInSeq, zeroGrad)

   dec:updateParameters(opt.learningRate)
   enc:updateParameters(opt.learningRate)
end
